import os
import pickle as pkl
import shutil
from dataclasses import dataclass
from itertools import chain, islice, tee
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch import Tensor
from tqdm import tqdm

from torch_geometric.data import Data
from torch_geometric.io import fs
from torch_geometric.typing import WITH_PT24

# One edge as a (head, relation, tail) triplet. Annotated as str, but any
# hashable type works: triplets are used as dictionary keys by the indexer.
TripletLike = Tuple[str, str, str]

# A knowledge graph (or subgraph) expressed as an iterable of triplets.
KnowledgeGraphLike = Iterable[TripletLike]


def ordered_set(values: Iterable[str]) -> List[str]:
    """Return the distinct items of ``values`` in first-seen order.

    Items must be hashable. Unlike ``set``, the relative order of first
    occurrences is preserved, because dict keys keep insertion order.
    """
    seen: Dict[str, None] = {}
    for item in values:
        seen.setdefault(item, None)
    return [*seen]


# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?

# Reserved node attribute names. Every key in NODE_KEYS must be present in a
# node_attr mapping handed to LargeGraphIndexer.
NODE_PID = "pid"  # Encodes node id

NODE_KEYS = {NODE_PID}

# Reserved edge attribute names. Every key in EDGE_KEYS must be present in an
# edge_attr mapping handed to LargeGraphIndexer.
EDGE_PID = "e_pid"  # Encodes source node, relation, destination node
EDGE_HEAD = "h"  # Encodes source node
EDGE_RELATION = "r"  # Encodes relation
EDGE_TAIL = "t"  # Encodes destination node
EDGE_INDEX = "edge_idx"  # Encodes source node, destination node

EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}

# Storage type for feature values: a plain sequence, or a tensor that the
# indexer selects from along its first dimension.
FeatureValueType = Union[Sequence[Any], Tensor]


@dataclass
class MappedFeature:
    """A feature whose values are keyed by another feature, not by id.

    ``values[i]`` belongs to the ``i``-th unique value (in first-seen order)
    of the feature called ``name``. Every node/edge sharing that value
    therefore shares the same entry of ``values``.
    """
    # Name of the feature whose unique values index into ``values``.
    name: str
    # One entry per unique value of feature ``name``.
    values: FeatureValueType

    def __eq__(self, value: object) -> bool:
        # Defer to Python's default handling (ultimately False) for foreign
        # types instead of crashing on a missing ``.name`` attribute.
        if not isinstance(value, MappedFeature):
            return NotImplemented
        if self.name != value.name:
            return False
        ours_is_tensor = isinstance(self.values, torch.Tensor)
        theirs_is_tensor = isinstance(value.values, torch.Tensor)
        if ours_is_tensor and theirs_is_tensor:
            return torch.equal(self.values, value.values)
        if ours_is_tensor or theirs_is_tensor:
            # A tensor and a non-tensor sequence are never considered equal;
            # comparing them directly would raise or yield a tensor.
            return False
        return self.values == value.values


# Allow-list MappedFeature for torch's restricted unpickler, so that
# attributes stored as MappedFeature (written with torch.save in
# LargeGraphIndexer.save) can be reloaded. WITH_PT24 presumably flags
# PyTorch >= 2.4, where add_safe_globals is available; TODO confirm.
if WITH_PT24:
    torch.serialization.add_safe_globals([MappedFeature])


class LargeGraphIndexer:
    """For a dataset that consists of multiple subgraphs that are assumed to
    be part of a much larger graph, collate the values into a large graph store
    to save resources.

    Every unique node id is stored once in ``node_attr[NODE_PID]``, and every
    unique ``(head, relation, tail)`` triplet once in ``edge_attr[EDGE_PID]``.
    Extra features are stored next to them. They are either aligned
    one-to-one with those ids, or held as a :class:`MappedFeature` indexed by
    the unique values of another feature.
    """
    def __init__(
        self,
        nodes: Iterable[str],
        edges: KnowledgeGraphLike,
        node_attr: Optional[Dict[str, List[Any]]] = None,
        edge_attr: Optional[Dict[str, List[Any]]] = None,
    ) -> None:
        r"""Constructs a new index that uniquely catalogs each node and edge
        by id. Not meant to be used directly.

        Args:
            nodes (Iterable[str]): Node ids in the graph. Must support
                ``len`` (e.g. a list).
            edges (KnowledgeGraphLike): Edge ids in the graph. Must support
                ``len``. Example: [("cats", "eat", "dogs")]
            node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
                attribute name and list of their values in order of unique node
                ids. Defaults to None.
            edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge
                attribute name and list of their values in order of unique edge
                ids. Defaults to None.

        Raises:
            AttributeError: If nodes or edges contain duplicates, if
                ``node_attr``/``edge_attr`` lack a reserved key, or if their
                ids do not match ``nodes``/``edges``.
        """
        self._nodes: Dict[str, int] = dict()
        self._edges: Dict[TripletLike, int] = dict()

        self._mapped_node_features: Set[str] = set()
        self._mapped_edge_features: Set[str] = set()

        if len(nodes) != len(set(nodes)):
            raise AttributeError("Nodes need to be unique")
        if len(edges) != len(set(edges)):
            raise AttributeError("Edges need to be unique")

        if node_attr is not None:
            # TODO: Validity checks btw nodes and node_attr
            self.node_attr = node_attr
            missing_node_keys = NODE_KEYS - set(self.node_attr.keys())
            if missing_node_keys:
                raise AttributeError(
                    f"Invalid node_attr object. Missing {missing_node_keys}")
            elif self.node_attr[NODE_PID] != nodes:
                raise AttributeError(
                    "Nodes provided do not match those in node_attr")
        else:
            self.node_attr = dict()
            self.node_attr[NODE_PID] = nodes

        # Node id -> position in node_attr[NODE_PID].
        for i, node in enumerate(self.node_attr[NODE_PID]):
            self._nodes[node] = i

        if edge_attr is not None:
            # TODO: Validity checks btw edges and edge_attr
            self.edge_attr = edge_attr
            missing_edge_keys = EDGE_KEYS - set(self.edge_attr.keys())
            if missing_edge_keys:
                raise AttributeError(
                    f"Invalid edge_attr object. Missing {missing_edge_keys}")
            # Edge ids live in edge_attr; node_attr has no EDGE_PID entry.
            elif self.edge_attr[EDGE_PID] != edges:
                raise AttributeError(
                    "Edges provided do not match those in edge_attr")
        else:
            self.edge_attr = dict()
            for default_key in EDGE_KEYS:
                self.edge_attr[default_key] = list()
            self.edge_attr[EDGE_PID] = edges

            # Derive per-edge head/relation/tail and (head, tail) node indices.
            for h, r, t in edges:
                self.edge_attr[EDGE_HEAD].append(h)
                self.edge_attr[EDGE_RELATION].append(r)
                self.edge_attr[EDGE_TAIL].append(t)
                self.edge_attr[EDGE_INDEX].append(
                    (self._nodes[h], self._nodes[t]))

        # Edge triplet -> position in edge_attr[EDGE_PID].
        for i, tup in enumerate(edges):
            self._edges[tup] = i

    @classmethod
    def from_triplets(
        cls,
        triplets: KnowledgeGraphLike,
        pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
    ) -> "LargeGraphIndexer":
        r"""Generate a new index from a series of triplets that represent edge
        relations between nodes.
        Formatted like (source_node, edge, dest_node).

        Args:
            triplets (KnowledgeGraphLike): Series of triplets representing
                knowledge graph relations. Example: [("cats", "eat", "dogs")].
                Duplicate triplets are collapsed to their first occurrence.
            pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
                Optional preprocessing function to apply to triplets.
                Defaults to None.

        Returns:
            LargeGraphIndexer: Index of unique nodes and edges.
        """
        # NOTE: Right now assumes that all trips can be loaded into memory
        if pre_transform is not None:
            triplets = [pre_transform(trip) for trip in triplets]

        nodes = []
        edges = []
        for h, r, t in triplets:
            nodes.append(h)
            nodes.append(t)
            edges.append((h, r, t))
        return cls(ordered_set(nodes), ordered_set(edges))

    @classmethod
    def collate(cls,
                graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer":
        r"""Combines a series of large graph indexes into a single large graph
        index.

        Only the triplets are merged. Node and edge features of the inputs
        are not carried over.

        Args:
            graphs (Iterable[LargeGraphIndexer]): Indices to be
                combined.

        Returns:
            LargeGraphIndexer: Singular unique index for all nodes and edges
                in input indices.
        """
        # FIXME Needs to merge node attrs and edge attrs?
        trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
        return cls.from_triplets(trips)

    def get_unique_node_features(self,
                                 feature_name: str = NODE_PID) -> List[str]:
        r"""Get all the unique values for a specific node attribute.

        Args:
            feature_name (str, optional): Name of feature to get.
                Defaults to NODE_PID.

        Returns:
            List[str]: List of unique values for the specified feature,
                in first-seen order.

        Raises:
            IndexError: If ``feature_name`` is a mapped feature.
            AttributeError: If nodes have no such feature.
        """
        try:
            if feature_name in self._mapped_node_features:
                raise IndexError(
                    "Only non-mapped features can be retrieved uniquely.")
            return ordered_set(self.get_node_features(feature_name))

        except KeyError as e:
            raise AttributeError(
                f"Nodes do not have a feature called {feature_name}") from e

    def add_node_feature(
        self,
        new_feature_name: str,
        new_feature_vals: FeatureValueType,
        map_from_feature: str = NODE_PID,
    ) -> None:
        r"""Adds a new feature that corresponds to each unique node in
            the graph.

        Args:
            new_feature_name (str): Name to call the new feature.
            new_feature_vals (FeatureValueType): Values to map for that
                new feature.
            map_from_feature (str, optional): Key of feature to map from.
                Size must match the number of feature values.
                Defaults to NODE_PID.

        Raises:
            AttributeError: If the feature already exists, if
                ``map_from_feature`` is itself mapped or missing, or if the
                number of values does not match its unique values.
        """
        if new_feature_name in self.node_attr:
            raise AttributeError("Features cannot be overridden once created")
        if map_from_feature in self._mapped_node_features:
            raise AttributeError(
                f"{map_from_feature} is already a feature mapping.")

        feature_keys = self.get_unique_node_features(map_from_feature)
        if len(feature_keys) != len(new_feature_vals):
            raise AttributeError(
                f"Expected encodings for {len(feature_keys)} unique features, "
                + f"but got {len(new_feature_vals)} encodings.")

        if map_from_feature == NODE_PID:
            # Node ids are unique, so values align one-to-one with nodes.
            self.node_attr[new_feature_name] = new_feature_vals
        else:
            self.node_attr[new_feature_name] = MappedFeature(
                name=map_from_feature, values=new_feature_vals)
            self._mapped_node_features.add(new_feature_name)

    def get_node_features(
        self,
        feature_name: str = NODE_PID,
        pids: Optional[Iterable[str]] = None,
    ) -> List[Any]:
        r"""Get node feature values for a given set of unique node ids.
            Returned values are not necessarily unique.

        Args:
            feature_name (str, optional): Name of feature to fetch. Defaults
                to NODE_PID.
            pids (Optional[Iterable[str]], optional): Node ids to fetch
                for. Defaults to None, which fetches all nodes.

        Returns:
            List[Any]: Node features corresponding to the specified ids. If
                the stored values are a tensor, a tensor is returned instead.
        """
        if feature_name in self._mapped_node_features:
            values = self.node_attr[feature_name].values
        else:
            values = self.node_attr[feature_name]
        # TODO: torch_geometric.utils.select
        if isinstance(values, torch.Tensor):
            # Gather all rows with one fancy-indexing call.
            idxs = list(
                self.get_node_features_iter(feature_name, pids,
                                            index_only=True))
            return values[torch.tensor(idxs).long()]
        return list(self.get_node_features_iter(feature_name, pids))

    def get_node_features_iter(
        self,
        feature_name: str = NODE_PID,
        pids: Optional[Iterable[str]] = None,
        index_only: bool = False,
    ) -> Iterator[Any]:
        """Iterator version of get_node_features. If index_only is True,
        yields indices into the stored values instead of values.
        """
        if pids is None:
            pids = self.node_attr[NODE_PID]

        if feature_name in self._mapped_node_features:
            feature_map_info = self.node_attr[feature_name]
            from_feature_name, to_feature_vals = (
                feature_map_info.name,
                feature_map_info.values,
            )
            # Position of each unique source value = row in to_feature_vals.
            from_feature_vals = self.get_unique_node_features(
                from_feature_name)
            feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}

            for pid in pids:
                idx = self._nodes[pid]
                from_feature_val = self.node_attr[from_feature_name][idx]
                to_feature_idx = feature_mapping[from_feature_val]
                if index_only:
                    yield to_feature_idx
                else:
                    yield to_feature_vals[to_feature_idx]
        else:
            for pid in pids:
                idx = self._nodes[pid]
                if index_only:
                    yield idx
                else:
                    yield self.node_attr[feature_name][idx]

    def get_unique_edge_features(self,
                                 feature_name: str = EDGE_PID) -> List[str]:
        r"""Get all the unique values for a specific edge attribute.

        Args:
            feature_name (str, optional): Name of feature to get.
                Defaults to EDGE_PID.

        Returns:
            List[str]: List of unique values for the specified feature,
                in first-seen order.

        Raises:
            IndexError: If ``feature_name`` is a mapped feature.
            AttributeError: If edges have no such feature.
        """
        try:
            if feature_name in self._mapped_edge_features:
                raise IndexError(
                    "Only non-mapped features can be retrieved uniquely.")
            return ordered_set(self.get_edge_features(feature_name))
        except KeyError as e:
            raise AttributeError(
                f"Edges do not have a feature called {feature_name}") from e

    def add_edge_feature(
        self,
        new_feature_name: str,
        new_feature_vals: FeatureValueType,
        map_from_feature: str = EDGE_PID,
    ) -> None:
        r"""Adds a new feature that corresponds to each unique edge in
        the graph.

        Args:
            new_feature_name (str): Name to call the new feature.
            new_feature_vals (FeatureValueType): Values to map for that new
                feature.
            map_from_feature (str, optional): Key of feature to map from.
                Size must match the number of feature values.
                Defaults to EDGE_PID.

        Raises:
            AttributeError: If the feature already exists, if
                ``map_from_feature`` is itself mapped or missing, or if the
                number of values does not match its unique values.
        """
        if new_feature_name in self.edge_attr:
            raise AttributeError("Features cannot be overridden once created")
        if map_from_feature in self._mapped_edge_features:
            raise AttributeError(
                f"{map_from_feature} is already a feature mapping.")

        feature_keys = self.get_unique_edge_features(map_from_feature)
        if len(feature_keys) != len(new_feature_vals):
            raise AttributeError(
                f"Expected encodings for {len(feature_keys)} unique features, "
                + f"but got {len(new_feature_vals)} encodings.")

        if map_from_feature == EDGE_PID:
            # Edge ids are unique, so values align one-to-one with edges.
            self.edge_attr[new_feature_name] = new_feature_vals
        else:
            self.edge_attr[new_feature_name] = MappedFeature(
                name=map_from_feature, values=new_feature_vals)
            self._mapped_edge_features.add(new_feature_name)

    def get_edge_features(
        self,
        feature_name: str = EDGE_PID,
        pids: Optional[KnowledgeGraphLike] = None,
    ) -> List[Any]:
        r"""Get edge feature values for a given set of unique edge ids.
            Returned values are not necessarily unique.

        Args:
            feature_name (str, optional): Name of feature to fetch.
                Defaults to EDGE_PID.
            pids (Optional[KnowledgeGraphLike], optional): Edge triplets to
                fetch for. Defaults to None, which fetches all edges.

        Returns:
            List[Any]: Edge features corresponding to the specified ids. If
                the stored values are a tensor, a tensor is returned instead.
        """
        if feature_name in self._mapped_edge_features:
            values = self.edge_attr[feature_name].values
        else:
            values = self.edge_attr[feature_name]

        # TODO: torch_geometric.utils.select
        if isinstance(values, torch.Tensor):
            # Gather all rows with one fancy-indexing call.
            idxs = list(
                self.get_edge_features_iter(feature_name, pids,
                                            index_only=True))
            return values[torch.tensor(idxs).long()]
        return list(self.get_edge_features_iter(feature_name, pids))

    def get_edge_features_iter(
        self,
        feature_name: str = EDGE_PID,
        pids: Optional[KnowledgeGraphLike] = None,
        index_only: bool = False,
    ) -> Iterator[Any]:
        """Iterator version of get_edge_features. If index_only is True,
        yields indices into the stored values instead of values.
        """
        if pids is None:
            pids = self.edge_attr[EDGE_PID]

        if feature_name in self._mapped_edge_features:
            feature_map_info = self.edge_attr[feature_name]
            from_feature_name, to_feature_vals = (
                feature_map_info.name,
                feature_map_info.values,
            )
            # Position of each unique source value = row in to_feature_vals.
            from_feature_vals = self.get_unique_edge_features(
                from_feature_name)
            feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}

            for pid in pids:
                idx = self._edges[pid]
                from_feature_val = self.edge_attr[from_feature_name][idx]
                to_feature_idx = feature_mapping[from_feature_val]
                if index_only:
                    yield to_feature_idx
                else:
                    yield to_feature_vals[to_feature_idx]
        else:
            for pid in pids:
                idx = self._edges[pid]
                if index_only:
                    yield idx
                else:
                    yield self.edge_attr[feature_name][idx]

    def to_triplets(self) -> Iterator[TripletLike]:
        """Iterate over the stored edge triplets in index order."""
        return iter(self.edge_attr[EDGE_PID])

    def save(self, path: str) -> None:
        """Write the index to directory ``path``.

        Any existing directory at ``path`` is deleted first. Id mappings are
        pickled, and each node/edge attribute is written with ``torch.save``
        to ``node_attr/<name>.pt`` or ``edge_attr/<name>.pt``.
        """
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "edges"), "wb") as f:
            pkl.dump(self._edges, f)
        with open(os.path.join(path, "nodes"), "wb") as f:
            pkl.dump(self._nodes, f)

        with open(os.path.join(path, "mapped_edges"), "wb") as f:
            pkl.dump(self._mapped_edge_features, f)
        with open(os.path.join(path, "mapped_nodes"), "wb") as f:
            pkl.dump(self._mapped_node_features, f)

        node_attr_path = os.path.join(path, "node_attr")
        os.makedirs(node_attr_path, exist_ok=True)
        for attr_name, vals in self.node_attr.items():
            torch.save(vals, os.path.join(node_attr_path, f"{attr_name}.pt"))

        edge_attr_path = os.path.join(path, "edge_attr")
        os.makedirs(edge_attr_path, exist_ok=True)
        for attr_name, vals in self.edge_attr.items():
            torch.save(vals, os.path.join(edge_attr_path, f"{attr_name}.pt"))

    @classmethod
    def from_disk(cls, path: str) -> "LargeGraphIndexer":
        """Load an index previously written by :meth:`save`.

        Args:
            path (str): Directory that was passed to :meth:`save`.

        Returns:
            LargeGraphIndexer: The restored index.

        .. warning::
            The id mappings are restored with :mod:`pickle`, which can run
            arbitrary code. Only load directories from trusted sources.
        """
        # Start from an empty index; every field is overwritten below.
        indexer = cls(list(), list())
        with open(os.path.join(path, "edges"), "rb") as f:
            indexer._edges = pkl.load(f)
        with open(os.path.join(path, "nodes"), "rb") as f:
            indexer._nodes = pkl.load(f)

        with open(os.path.join(path, "mapped_edges"), "rb") as f:
            indexer._mapped_edge_features = pkl.load(f)
        with open(os.path.join(path, "mapped_nodes"), "rb") as f:
            indexer._mapped_node_features = pkl.load(f)

        node_attr_path = os.path.join(path, "node_attr")
        for fname in os.listdir(node_attr_path):
            # splitext keeps dots inside the attribute name ("a.b.pt" -> "a.b").
            key = os.path.splitext(fname)[0]
            indexer.node_attr[key] = fs.torch_load(
                os.path.join(node_attr_path, fname))

        edge_attr_path = os.path.join(path, "edge_attr")
        for fname in os.listdir(edge_attr_path):
            key = os.path.splitext(fname)[0]
            indexer.edge_attr[key] = fs.torch_load(
                os.path.join(edge_attr_path, fname))

        return indexer

    def to_data(self, node_feature_name: str,
                edge_feature_name: Optional[str] = None) -> Data:
        """Return a Data object containing all the specified node and
            edge features and the graph.

        Args:
            node_feature_name (str): Feature to use for nodes
            edge_feature_name (Optional[str], optional): Feature to use for
                edges. Defaults to None, in which case ``edge_attr`` is None.

        Returns:
            Data: Data object containing the specified node and
                edge features and the graph.
        """
        x = torch.Tensor(self.get_node_features(node_feature_name))
        node_id = torch.LongTensor(range(len(x)))

        # (num_edges, 2) pairs -> (2, num_edges); reshape keeps the (2, 0)
        # layout for graphs without edges.
        edge_pairs = self.get_edge_features(EDGE_INDEX)
        edge_index = torch.LongTensor(edge_pairs).reshape(-1, 2).t()

        edge_attr = (self.get_edge_features(edge_feature_name)
                     if edge_feature_name is not None else None)
        # Count edges from edge_index so this also works when edge_attr is
        # None.
        edge_id = torch.LongTensor(range(edge_index.size(1)))

        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                    edge_id=edge_id, node_id=node_id)

    @staticmethod
    def _attrs_equal(ours: Dict[str, Any], theirs: Dict[str, Any]) -> bool:
        """Compare two attribute dicts that are known to share their keys."""
        for key, mine in ours.items():
            other = theirs[key]
            if not isinstance(mine, type(other)):
                return False
            if isinstance(mine, torch.Tensor):
                if not torch.equal(mine, other):
                    return False
            elif not mine == other:
                return False
        return True

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, LargeGraphIndexer):
            return NotImplemented
        # Cheap structural checks first. Stopping here also guarantees the
        # attribute dicts share their keys before comparing them item-wise.
        if (self._nodes != value._nodes or self._edges != value._edges
                or self.node_attr.keys() != value.node_attr.keys()
                or self.edge_attr.keys() != value.edge_attr.keys()
                or self._mapped_node_features != value._mapped_node_features
                or self._mapped_edge_features != value._mapped_edge_features):
            return False
        return (self._attrs_equal(self.node_attr, value.node_attr)
                and self._attrs_equal(self.edge_attr, value.edge_attr))


def get_features_for_triplets_groups(
    indexer: LargeGraphIndexer,
    triplet_groups: Iterable[KnowledgeGraphLike],
    node_feature_name: str = "x",
    edge_feature_name: str = "edge_attr",
    pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip,
    verbose: bool = False,
    max_batch_size: int = 250,
    num_workers: Optional[int] = None,
) -> Iterator[Data]:
    """Given an indexer and a series of triplet groups (like a dataset),
    retrieve the specified node and edge features for each triplet from the
    index.

    Args:
        indexer (LargeGraphIndexer): Indexer containing desired features
        triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of
            triplets to fetch features for
        node_feature_name (str, optional): Node feature to fetch.
            Defaults to "x".
        edge_feature_name (str, optional): edge feature to fetch.
            Defaults to "edge_attr".
        pre_transform (Callable[[TripletLike], TripletLike]):
            Optional preprocessing to perform on triplets.
            Defaults to the identity function.
        verbose (bool, optional): Whether to print progress.
            Defaults to False.
        max_batch_size (int, optional):
            Maximum batch size for fetching features.
            Defaults to 250.
        num_workers (int, optional):
            Number of worker threads to use for fetching features.
            Defaults to None (``multiprocessing.cpu_count()``).

    Yields:
        Iterator[Data]: For each triplet group, in input order, a data object
            containing the unique graph and features from the index.
    """
    import multiprocessing as mp
    import multiprocessing.pool as mpp

    def apply_transform(trips: Iterable[TripletLike]) -> Iterator[TripletLike]:
        for trip in trips:
            yield pre_transform(tuple(trip))

    # Carefully trying to avoid loading all triplets into memory at once
    # While also still tracking the number of elements for tqdm
    transformed_groups: List[Iterator[TripletLike]] = [
        apply_transform(triplets) for triplets in triplet_groups
    ]

    # For each KG, gather the node keys, edge keys and edge indices needed
    # to construct its Data object.
    node_keys: List[List[Any]] = []
    edge_keys: List[List[Any]] = []
    edge_index: List[List[Any]] = []

    for kg_triplets in tqdm(transformed_groups, disable=not verbose):
        kg_triplets_nodes, kg_triplets_edge_keys, kg_triplets_edge_index = tee(
            kg_triplets, 3)
        # Don't apply pre_transform here; it has already been applied to
        # the triplet groups above.
        small_graph_indexer = LargeGraphIndexer.from_triplets(
            kg_triplets_nodes)

        node_keys.append(small_graph_indexer.get_node_features())
        edge_keys.append(
            small_graph_indexer.get_edge_features(pids=kg_triplets_edge_keys))
        edge_index.append(
            small_graph_indexer.get_edge_features(
                EDGE_INDEX,
                kg_triplets_edge_index,
            ))

    # Features are fetched from the large index in batches. Batches that are
    # too small waste compute, since each get-features call has an upfront
    # cost. Batches that are too large waste memory, since all resulting
    # embeddings of a batch are held at once.

    def _fetch_feature_batch(batches) -> List[Data]:
        """Fetch features for one batch of groups and build their Data."""
        node_key_batch, edge_key_batch, edge_index_batch = batches
        node_feats = indexer.get_node_features(
            feature_name=node_feature_name,
            pids=chain.from_iterable(node_key_batch))
        edge_feats = indexer.get_edge_features(
            feature_name=edge_feature_name,
            pids=chain.from_iterable(edge_key_batch))

        data_objs = []
        last_node_idx, last_edge_idx = 0, 0
        for (nkeys, ekeys, eidx) in zip(node_key_batch, edge_key_batch,
                                        edge_index_batch):
            # Features for all groups in the batch are concatenated; slice
            # out this group's share.
            nlen, elen = len(nkeys), len(ekeys)
            x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
            last_node_idx += nlen

            edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
                                                elen])
            last_edge_idx += elen

            # (num_edges, 2) -> (2, num_edges); stays 2-D for empty groups.
            edge_idx = torch.LongTensor(eidx).reshape(-1, 2).t()

            data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
            # Only this group's keys, not the keys of every group.
            data_obj[NODE_PID] = nkeys
            data_obj[EDGE_PID] = ekeys
            data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
            data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]

            data_objs.append(data_obj)
        return data_objs

    # NOTE: Backport of itertools.batched from Python 3.12
    def batched(iterable, n, *, strict=False):
        # batched('ABCDEFG', 3) → ABC DEF G
        if n < 1:
            raise ValueError('n must be at least one')
        iterator = iter(iterable)
        while batch := tuple(islice(iterator, n)):
            if strict and len(batch) != n:
                raise ValueError('batched(): incomplete batch')
            yield batch

    num_workers = num_workers if num_workers is not None else mp.cpu_count()
    num_workers = max(1, num_workers)
    ideal_batch_size = min(max_batch_size,
                           max(1,
                               len(transformed_groups) // num_workers))

    node_key_batches = batched(node_keys, ideal_batch_size)
    edge_key_batches = batched(edge_keys, ideal_batch_size)
    edge_index_batches = batched(edge_index, ideal_batch_size)
    batches = zip(node_key_batches, edge_key_batches, edge_index_batches)

    # The worker returns a materialized list so the fetching really happens
    # in the pool threads. imap preserves input order and streams results
    # batch by batch.
    with mpp.ThreadPool(num_workers) as pool:
        for data_objs in pool.imap(_fetch_feature_batch, batches):
            yield from data_objs


def get_features_for_triplets(
    indexer: LargeGraphIndexer,
    triplets: KnowledgeGraphLike,
    node_feature_name: str = "x",
    edge_feature_name: str = "edge_attr",
    pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip,
    verbose: bool = False,
) -> Data:
    """Build the Data object for a single group of triplets.

    This is the one-group special case of
    ``get_features_for_triplets_groups``: the triplets are wrapped in a
    single-element list and the first (only) result is returned.

    Args:
        indexer (LargeGraphIndexer): Index holding the requested features.
        triplets (KnowledgeGraphLike): Triplets to fetch features for.
        node_feature_name (str, optional): Node feature to fetch.
            Defaults to "x".
        edge_feature_name (str, optional): Edge feature to fetch.
            Defaults to "edge_attr".
        pre_transform (Callable[[TripletLike], TripletLike]): Preprocessing
            applied to each triplet. Defaults to the identity function.
        verbose (bool, optional): Whether to show progress. Defaults to False.

    Returns:
        Data: The unique graph for ``triplets`` together with its features
            from the index.
    """
    # With a single group, a batch size of one keeps it in its own batch.
    results = get_features_for_triplets_groups(
        indexer,
        [triplets],
        node_feature_name=node_feature_name,
        edge_feature_name=edge_feature_name,
        pre_transform=pre_transform,
        verbose=verbose,
        max_batch_size=1,
    )
    return next(results)
